import tqdm
import numpy as np
import tensorflow as tf
import pickle
import sys
import os
from data_loader.unet_data_generator import DataGenerator



class UNET_Trainer:
    """Train a model whose output channels 0 and 1 are the jx and jy predictions.

    Hyper-parameters are read from ``config``; the attributes used here are
    ``batch_size``, ``im_h``, ``im_w``, ``learning_rate``, ``num_epochs``,
    ``num_files``, ``checkpoint_dir`` and ``graph_file``.
    """

    def __init__(self, model, data, valdata, config):
        """Store the model, data sources and configuration.

        Args:
            model: Keras model; its output is indexed as ``[:, :, :, c]`` with
                channel 0 = jx and channel 1 = jy.
            data: training data source exposing ``next_batch(batch_size)``.
                NOTE: :meth:`train` replaces it with a fresh ``DataGenerator``
                for every data file.
            valdata: validation data source exposing ``next_batch(batch_size)``.
            config: configuration object (see class docstring).
        """
        self.model = model
        self.config = config
        self.data = data
        self.valdata = valdata

        # Start out as Keras Mean metrics, but train_step / val_step overwrite
        # them with the most recent per-batch scalar loss tensor.
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.val_loss = tf.keras.metrics.Mean(name='val_loss')

    def custom_loss(self, targets, jx_pred, jy_pred):
        """Return the MSE averaged over the jx and jy channels.

        Args:
            targets: rank-4 tensor/array; channel 0 holds the jx target and
                channel 1 the jy target.
            jx_pred: jx prediction, reshapeable to ``(batch, im_h, im_w)``.
            jy_pred: jy prediction, reshapeable to ``(batch, im_h, im_w)``.

        Returns:
            Scalar tensor ``(MSE(jx) + MSE(jy)) / 2``.
        """
        # Let TensorFlow infer the batch dimension (-1) instead of hard-coding
        # config.batch_size, so a smaller final batch does not break the reshape.
        shape = [-1, self.config.im_h, self.config.im_w]
        jx_pred = tf.reshape(jx_pred, shape)
        jy_pred = tf.reshape(jy_pred, shape)

        mse_jx = tf.reduce_mean(tf.square(targets[:, :, :, 0] - jx_pred))
        mse_jy = tf.reduce_mean(tf.square(targets[:, :, :, 1] - jy_pred))
        return (mse_jx + mse_jy) / 2

    def valcustom_loss(self, targets, jx_pred, jy_pred):
        """Return the validation loss; identical to :meth:`custom_loss`."""
        return self.custom_loss(targets, jx_pred, jy_pred)

    def train_step(self, epoch, optimizer):
        """Run one optimisation step on a single training batch.

        Args:
            epoch: current epoch index (unused; kept for the call signature).
            optimizer: Keras optimizer that applies the gradients.

        Returns:
            Scalar loss tensor for this batch (also stored in ``self.train_loss``).
        """
        # NOTE(review): a new generator is created on every call; if
        # ``next_batch`` restarts from the beginning each time, every step sees
        # the same batch — confirm against DataGenerator.
        raw_data, targets = next(self.data.next_batch(self.config.batch_size))

        # Multiplicative input noise drawn from [0.99, 1.01). cprob blends the
        # noisy (cprob=1) and clean (cprob=0) inputs; noise is on by default.
        cprob = 1
        noise = tf.random.uniform(shape=tf.shape(raw_data), minval=0.99, maxval=1.01)
        raw_data_input = tf.math.multiply(raw_data, noise) * cprob + raw_data * (1 - cprob)

        with tf.GradientTape() as tape:
            # training=True so layers with train/inference-specific behaviour
            # (e.g. Dropout, BatchNormalization) run in training mode here.
            J_pred = self.model(raw_data_input, training=True)
            loss = self.custom_loss(targets, J_pred[:, :, :, 0], J_pred[:, :, :, 1])

        gradients = tape.gradient(loss, self.model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        self.train_loss = loss
        return self.train_loss

    def val_step(self, epoch):
        """Evaluate the model on a single validation batch (no weight update).

        Args:
            epoch: current epoch index (unused; kept for the call signature).

        Returns:
            Scalar validation loss tensor (also stored in ``self.val_loss``).
        """
        # NOTE(review): same fresh-generator-per-call pattern as train_step.
        raw_valdata, valtargets = next(self.valdata.next_batch(self.config.batch_size))
        J_pred = self.model(raw_valdata, training=False)
        self.val_loss = self.valcustom_loss(valtargets, J_pred[:, :, :, 0], J_pred[:, :, :, 1])
        return self.val_loss

    @staticmethod
    def _to_float(value):
        """Return ``value`` as a Python float.

        Accepts a scalar tensor/array/number, or a metric-like object with a
        ``result()`` method (``self.train_loss`` is still a Keras Mean metric
        if no step has run yet).
        """
        if hasattr(value, 'result'):
            value = value.result()
        return float(value)

    def train(self):
        """Train for ``num_epochs`` passes over ``num_files`` data files.

        For each file, ``self.data`` is replaced by ``DataGenerator(config,
        file_num + 1)`` and every training step is followed by one validation
        step. After each file the latest train/val losses are recorded; after
        each epoch the model is saved to ``checkpoint_dir + "model.keras"``.
        Finally the last losses are appended to ``checkpoint_dir + 'train_loss'``
        / ``'val_loss'`` and the loss history array is written to
        ``config.graph_file`` with ``np.save``.
        """
        num_files = self.config.num_files
        # Row 0: training loss, row 1: validation loss; one column per (epoch, file).
        loss_training = np.zeros((2, self.config.num_epochs * num_files))

        optimizer = tf.keras.optimizers.RMSprop(learning_rate=self.config.learning_rate)

        for epoch in range(self.config.num_epochs):
            for file_num in range(num_files):
                self.data = DataGenerator(self.config, file_num + 1)
                steps_per_file = self.data.len // self.config.batch_size

                pbar = tqdm.tqdm(total=steps_per_file, desc='Steps', position=0)
                train_status = tqdm.tqdm(total=0, bar_format='{desc}', position=1)
                try:
                    for _ in range(steps_per_file):
                        self.train_step(epoch, optimizer)
                        self.val_step(epoch)
                        train_status.set_description_str(
                            f'Epoch: {epoch} Loss: {self._to_float(self.train_loss)} '
                            f'Val Loss: {self._to_float(self.val_loss)}')
                        pbar.update()
                finally:
                    # Release both bars so the next file's bars render cleanly.
                    pbar.close()
                    train_status.close()

                train_loss = self._to_float(self.train_loss)
                val_loss = self._to_float(self.val_loss)
                template = 'Epoch {}, File Number {},Loss: {}, ValLoss: {}'
                print(template.format(epoch, file_num, train_loss, val_loss))

                column = epoch * num_files + file_num
                loss_training[0, column] = train_loss
                loss_training[1, column] = val_loss

            self.model.save(self.config.checkpoint_dir + "model.keras")
            # To save a different model/checkpoint at each epoch (will take up a lot more disk space!):
            # self.model.save(os.path.join(self.config.checkpoint_dir, str(epoch) + '.h5'))

        ckpt_dir = self.config.checkpoint_dir
        with open(ckpt_dir + 'train_loss', "a") as train_file, open(ckpt_dir + 'val_loss', "a") as val_file:
            train_file.write(f'{self._to_float(self.train_loss)},')
            val_file.write(f'{self._to_float(self.val_loss)},')

        with open(self.config.graph_file, 'wb') as f:
            np.save(f, loss_training)